package mosdi.subcommands;

import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.StringTokenizer;

import mosdi.fa.DFAFactory;
import mosdi.fa.FiniteMemoryTextModel;
import mosdi.matching.HorspoolMatcher;
import mosdi.matching.Matcher;
import mosdi.paa.DAA;
import mosdi.paa.MinimizableDAA;
import mosdi.paa.MinimizedDAA;
import mosdi.paa.TextBasedPAA;
import mosdi.paa.apps.MatchCountDAA;
import mosdi.paa.apps.UnconstrainedCpGDAA;
import mosdi.util.Alphabet;
import mosdi.util.BitArray;
import mosdi.util.CodonAlphabet;
import mosdi.util.Combinatorics;
import mosdi.util.FalseDiscoveryRate;
import mosdi.util.Histogram;
import mosdi.util.IntArrayList;
import mosdi.util.Log;
import mosdi.util.NamedSequence;
import mosdi.util.ObjectWithPValue;
import mosdi.util.SequenceUtils;
import mosdi.util.iterators.SegmentIterator;
import mosdi.util.iterators.SegmentIterator.Segment;

/**
 * Subcommand "predict": predicts CpG islands in coding sequences.
 *
 * <p>Every segment of an input sequence that starts and ends with a CpG occurrence is a
 * candidate. For each candidate, the probability of observing at least as many CpGs in a
 * random text of the same length is computed with a PAA ({@link TextBasedPAA}) built from a
 * CpG-counting DAA and a background text model estimated from the input. Candidates are then
 * filtered by controlling the false discovery rate, non-overlapping islands are chosen greedily
 * in order of increasing p-value, and the result is printed.
 */
public class PredictIslandsSubcommand extends Subcommand {
	@Override
	public String usage() {
		return
		super.usage()+" [options] <exons.fasta>\n" +
		"\n" +
		"  <exons.fasta> a file containing coding regions in FASTA format.\n" +
		"\n" +
		"Options:\n" +
		"  -O <order>: Order of background model to be estimated from input text (default: 1)\n" +
		"  -p <pseudocounts>: Pseudocounts for estimation of text model (default: 0.1)\n" +
		"  -f <FDR-threshold>: sets the level the false discovery rate is to be controlled at.\n" +
		"                      (default: 0.05)\n" +
		"  -c: use conditional p-value, i.e. do not compute probability of observing n or more CGs,\n" +
		"      but the probability of n or more CGs given that the segment starts and ends with CG\n" +
		"  -t: translate region's positions to genomic coordinates. If this option is given, the\n" +
		"      sequence names provided in the input file must be of the form\n" +
		"      \"chromosome startpos endpos strand\", e.g. \"8 145662611 145662709 +\".\n" +
		"      Coordinates are interpreted 1-based and endpos is assumed to be included in\n" +
		"      each interval. If the strand is '-', the corresponding sequence is assumed to be the\n" +
		"      reverse complement of the forward strand.\n" +
		"  -g <progress-interval>: if in verbose mode (-v), this option controls the intervals\n" +
		"                          (in percent) in which current progress is shown (default: 1.0)\n" +
		"  -H <histogram-filename>: write a histogram of p-values to the given filename\n" +
		"  -b <bin-count>: this option controls the number of bins for the histogram requested\n" +
		"                  by giving option -H (default: 1000).\n" +
		"  -C: Enable \"codon-awareness\". In this case, the given input sequence\n" +
		"      lengths must be multiples of 3 and are expected to start in frame 0,\n" +
		"      i.e. the first three letters can be interpreted as a codon.\n" +
		"  -2: Only count CpGs in frame 2, that is, CpGs that span a codon boundary, and ignore\n" +
		"      CpGs in frame 1.";
	}

	@Override
	public String description() {
		return "Given sequences of coding regions, predicts CpG islands.";
	}

	@Override
	public String name() {
		return "predict";
	}

	/**
	 * A candidate segment annotated with its p-value. Instances are ordered by p-value first
	 * and, for equal p-values, by the ordering of {@link SegmentIterator.Segment}.
	 */
	public static class SegmentWithPValue extends SegmentIterator.Segment implements ObjectWithPValue {
		private final double pvalue;
		public SegmentWithPValue(SegmentIterator iterator, SegmentIterator.Segment segment, double pvalue) {
			iterator.super(segment);
			this.pvalue = pvalue;
		}
		/** Two instances are equal iff {@link #compareTo} considers them equal. */
		@Override
		public boolean equals(Object obj) {
			if (!(obj instanceof SegmentWithPValue)) return false;
			return compareTo((SegmentWithPValue)obj) == 0;
		}
		/** Consistent with {@link #equals}: equal instances have Double.compare()-equal p-values. */
		@Override
		public int hashCode() {
			return Double.valueOf(pvalue).hashCode();
		}
		/**
		 * Compares by p-value, then by segment.
		 * @throws IllegalArgumentException if {@code o} is not a SegmentWithPValue
		 */
		@Override
		public int compareTo(Segment o) {
			if (!(o instanceof SegmentWithPValue)) throw new IllegalArgumentException();
			// Double.compare gives a total order (also for NaN), as required by Collections.sort
			int byPValue = Double.compare(pvalue, ((SegmentWithPValue)o).pvalue);
			if (byPValue != 0) return byPValue;
			return super.compareTo(o);
		}
		@Override
		public double getPValue() {
			return pvalue;
		}
		@Override
		public double getLogPValue() {
			return Math.log(pvalue);
		}	
	}
	
	/**
	 * Runs the island prediction.
	 * @param args command line arguments (options and the FASTA file name)
	 * @return 0 on success, 1 if an input or output file could not be processed or a
	 *         sequence name could not be parsed (option -t)
	 */
	@Override
	public int run(String[] args) {
		parseOptions(args, 1, "O:Cf:cp:tH:b:g:2");

		// Option dependencies
		impliedOptions("2", "C");

		// Mandatory arguments
		String exonFilename = getStringArgument(0);

		// Options
		int modelOrder = getNonNegativeIntOption("O", 1);
		boolean codonAware = getBooleanOption("C", false);
		double fdrThreshold = getRangedDoubleOption("f", 0, 1, 0.05);
		boolean conditionalPValues = getBooleanOption("c", false);
		double pseudoCounts = getRangedDoubleOption("p", 0.0, Double.POSITIVE_INFINITY, 0.1);
		boolean translateCoordinates = getBooleanOption("t", false);
		String histogramFilename = getStringOption("H", null);
		int histogramBinCount = getPositiveIntOption("b", 1000);
		double progressInterval = getRangedDoubleOption("g", 1e-10, 100.0, 1.0);
		boolean onlyFrame2CpGs = getBooleanOption("2", false);
		
		Alphabet alphabet = Alphabet.getDnaAlphabet();
		Log.printf(Log.Level.VERBOSE, "Reading file %s\n", exonFilename);
		List<NamedSequence> namedSequences = null;
		try {
			namedSequences = SequenceUtils.readFastaFile(exonFilename, alphabet, true);
		} catch (Exception e) {
			Log.errorln(e.getMessage());
			return 1;
		}
		List<int[]> sequences = SequenceUtils.sequenceList(namedSequences);
		Log.printf(Log.Level.VERBOSE, "Read %,d sequences\n", namedSequences.size());
		// Background model: over codons if codon-aware, over nucleotides otherwise.
		FiniteMemoryTextModel textModel = null;
		FiniteMemoryTextModel codonModel = null;
		List<int[]> codonSequences = null;
		if (codonAware) {
			codonSequences = SequenceUtils.codonSequenceList(namedSequences);
			codonModel = SequenceUtils.buildTextModelFromSequences(codonSequences, CodonAlphabet.size(), modelOrder, pseudoCounts);
		} else {
			textModel = SequenceUtils.buildTextModelFromSequences(sequences, alphabet.size(), modelOrder, pseudoCounts);
		}
		// -------- find and record all occurrences of CG -----------
		String pattern = "CG";
		Matcher m = new HorspoolMatcher(alphabet.size(), alphabet.buildIndexArray(pattern));
		long totalCGs = 0;
		// If codonAware, the following variables refer to CGs without coding constraints only.
		// Furthermore, positions refer to the sequences of codons, not nucleotides (i.e. positions
		// must be multiplied by 3)
		List<int[]> occurrences = new ArrayList<int[]>(sequences.size());
		// If codonAware, up to two CGs can end (and are counted) in one codon. To reflect that,
		// occurrences are weighted.
		List<int[]> occurrenceWeights = null;
		long totalLength = 0; // total length of input sequences
		int maxCount = 0; // largest number of CGs found in a single exon
		long totalCandidates = 0; // number of candidate regions for CpG islands
		long maxCandidateLength = 0; // length of longest such candidate
		long totalUnconstrainedCGs = 0;
		DAA daa;
		TextBasedPAA paa;
		if (codonAware) {
			for (int[] s : sequences) {
				totalCGs += m.findMatches(s);
			}
			occurrenceWeights = new ArrayList<int[]>(sequences.size());
			daa = new UnconstrainedCpGDAA(0, !onlyFrame2CpGs);
			for (int[] cs : codonSequences) {
				// find CpGs as counted by DAA, i.e. only count unconstrained CpGs
				totalLength += cs.length;
				int state = daa.getStartState();
				IntArrayList positions = new IntArrayList();
				IntArrayList weights = new IntArrayList();
				int matches = 0;
				for (int i=0; i<cs.length; ++i) {
					state = daa.getTransitionTarget(state, cs[i]);
					if (daa.getEmission(state)>0) {
						positions.add(i);
						weights.add(daa.getEmission(state));
						matches += daa.getEmission(state); 
					}
				}
				totalUnconstrainedCGs += matches;
				maxCount = Math.max(maxCount, matches);
				occurrences.add(positions.toIntArray());
				occurrenceWeights.add(weights.toIntArray());
				if (positions.size()>=2) {
					totalCandidates += Combinatorics.binomial(positions.size(), 2);
					maxCandidateLength = Math.max(maxCandidateLength, positions.get(positions.size()-1)-positions.get(0)+1);
				}
			}
			// The DAA underlying the PAA must count exactly the same CpGs as the DAA used above
			// to count observed occurrences (in particular, respect option -2); otherwise the
			// observed counts are compared against the wrong null distribution.
			MinimizableDAA minimizableDaa = new UnconstrainedCpGDAA(maxCount, !onlyFrame2CpGs);
			daa = new MinimizedDAA(minimizableDaa);
			paa = new TextBasedPAA(daa, codonModel);
			Log.printf(Log.Level.VERBOSE, "Total sequence length: %,d codons = %,d nt\n", totalLength, totalLength*3);
			Log.printf(Log.Level.VERBOSE, "Maximum number of occurrences of CGs in frame 1 or 2 in one sequence: %,d\n", maxCount);
			Log.printf(Log.Level.VERBOSE, "Total number of %ss: %,d (%4.1f%% of all dinucleotides), unconstrained CGs: %,d (%4.1f%% of all CGs)\n", pattern, totalCGs, 100.0*totalCGs/(totalLength*3-sequences.size()), totalUnconstrainedCGs, 100.0*totalUnconstrainedCGs/totalCGs);
			Log.printf(Log.Level.VERBOSE, "Segments to be examined: %,d\n", totalCandidates);
			Log.printf(Log.Level.VERBOSE, "Length of longest segment to be examined: %,d codons / %,d nt\n", maxCandidateLength, maxCandidateLength*3);
			Log.printf(Log.Level.VERBOSE, "States of text model / DAA / PAA: %d %d %d\n", codonModel.getStateCount(), daa.getStateCount(), paa.getStateCount());
		} else {
			for (int[] s : sequences) {
				totalLength += s.length;
				int[] matches = m.findAllMatchPositions(s);
				totalCGs += matches.length;
				occurrences.add(matches);
				maxCount = Math.max(maxCount, matches.length);
				if (matches.length>=2) {
					totalCandidates += Combinatorics.binomial(matches.length, 2);
					maxCandidateLength = Math.max(maxCandidateLength, matches[matches.length-1]+2-matches[0]);
				}
			}
			daa = new MatchCountDAA(DFAFactory.buildFromIupacPattern(pattern, false), maxCount);
			paa = new TextBasedPAA(daa, textModel);
			Log.printf(Log.Level.VERBOSE, "Total sequence length: %,d\n", totalLength);
			Log.printf(Log.Level.VERBOSE, "Maximum number of occurrences of %s in one sequence: %,d\n", pattern, maxCount);
			Log.printf(Log.Level.VERBOSE, "Total number of %ss: %,d (%4.1f%% of all dinucleotides)\n", pattern, totalCGs, 100.0*totalCGs/(totalLength-sequences.size()));
			Log.printf(Log.Level.VERBOSE, "Segments to be examined: %,d\n", totalCandidates);
			Log.printf(Log.Level.VERBOSE, "Length of longest segment to be examined: %,d\n", maxCandidateLength);
			Log.printf(Log.Level.VERBOSE, "States of text model / DAA / PAA: %d %d %d\n", textModel.getStateCount(), daa.getStateCount(), paa.getStateCount());
		}
		// joint distribution of states and values
		double[][] distribution;
		// number of steps done, i.e. length of random text "distribution" corresponds to
		int n;
		// a list of states which correspond to a match, i.e. being in such a state
		// means that a CG has just been encountered.
		List<Integer> acceptingStates = new ArrayList<Integer>();
		for (int state=0; state<paa.getStateCount(); ++state) {
			if (paa.getEmission(state)>0) acceptingStates.add(state);
		}
		// create initial probability distribution
		if (conditionalPValues) {
			// we initialize "distribution" to reflect that a CG "has already been read",
			// i.e. the state distribution is the equilibrium state distribution restricted
			// to accepting states. If we are codon-aware (and the underlying text model is therefore
			// not over the nucleotide alphabet but over the codon alphabet), this means that one
			// codon has already been read. In case of nucleotide, this means that two characters 
			// have already been read.
			n = codonAware?1:2;
			distribution = getConditionalInitialDistribution(paa, acceptingStates);
		} else {
			n = 0;
			if (codonAware) {
				distribution = getCodonInitialDistribution(paa);
			} else {
				distribution = paa.stateValueStartDistribution();
			}
		}
		ProgressPrinter progressPrinter = new ProgressPrinter(progressInterval/100.0,totalCandidates,codonAware?maxCandidateLength*3:maxCandidateLength);
		Histogram histogram = null;
		if (histogramFilename!=null) {
			histogram = new Histogram(0.0, 1.0, histogramBinCount);
		}
		long candidateCounter = 0;
		// right-cumulative value distribution for the current length n; null = must be recomputed
		double[] rightCDF = null;
		// unconditional case: rightCDF[k] is valid for all k >= j
		int j = -1;
		SegmentIterator iterator;
		if (codonAware) {
			iterator = new SegmentIterator(CodonAlphabet.size(), codonSequences, 1, occurrences, occurrenceWeights);
		} else {
			iterator = new SegmentIterator(alphabet.size(), sequences, pattern.length(), occurrences);
		}
		List<SegmentWithPValue> islandCandidates = new ArrayList<SegmentWithPValue>();
		Log.startTimer();
		// MAIN LOOP: iterate over all segments starting and ending with CG, ordered by length
		while (iterator.hasNext()) {
			// fetch next segment and update distribution as long as necessary to
			// match the segment's length
			SegmentIterator.Segment segment = iterator.next();
			candidateCounter += 1;
			while (n<segment.length()) {
				progressPrinter.print(candidateCounter, codonAware?n*3:n);
				distribution = paa.updateStateValueDistribution(distribution);
				n += 1;
				rightCDF = null;
			}
			// Next step: derive right cumulative distribution (i.e. table of p-values) from
			// the two dimensional distribution
			if (conditionalPValues) {
				if (rightCDF==null) {
					rightCDF = conditionalRightCDF(distribution, acceptingStates);
				}
			} else {
				if (rightCDF==null) {
					rightCDF = new double[paa.getValueCount()];
					j = rightCDF.length - 1;
					// only sum over the largest value and compute the others as needed
					rightCDF[j] = computeValueProbability(distribution,j);
				}
				while (j>segment.getPatternCount()) {
					j -= 1;
					rightCDF[j] = computeValueProbability(distribution,j) + rightCDF[j+1];
				}
			}
			double pvalue = rightCDF[segment.getPatternCount()];
			if (histogram!=null) {
				histogram.add(Math.min(1.0,pvalue));
			}
			if (pvalue<=fdrThreshold) {
				islandCandidates.add(new SegmentWithPValue(iterator, segment, pvalue));
				if (Log.levelAtLeast(Log.Level.DEBUG)) {
					Log.printf(Log.Level.DEBUG, "Segment: %d %d %d %d %e \"%s\"\n", segment.getStartPosition(), segment.getEndPosition(), segment.length(), segment.getPatternCount(), pvalue, namedSequences.get(segment.getStringIndex()).getName());
				}
			}
			progressPrinter.print(candidateCounter, codonAware?n*3:n);
		}
		if (candidateCounter!=totalCandidates) throw new IllegalStateException("This is a BUG!");
		Log.printf(Log.Level.VERBOSE, "Examined %,d candidates for CpG islands, of which %,d have a p-value below %f\n", totalCandidates, islandCandidates.size(), fdrThreshold);
		Log.restartTimer("Computation of p-values for all candidates");
		if (histogram!=null) {
			if (!writeHistogram(histogram, histogramFilename)) return 1;
			Log.restartTimer("Writing histogram");
		}
		Collections.sort(islandCandidates);
		int islandsToRetain = FalseDiscoveryRate.control(islandCandidates, fdrThreshold, totalCandidates); 
		Log.printf(Log.Level.VERBOSE, "%,d candidates remain after controlling FDR at %f\n", islandsToRetain, fdrThreshold);
		Log.restartTimer("Performing FDR control");
		List<SegmentWithPValue> selectedIslands = greedyChoice(islandCandidates, islandsToRetain, occurrences);
		Log.printf(Log.Level.VERBOSE, "%,d candidates remain after greedy algorithm\n", selectedIslands.size());
		Log.stopTimer("Running greedy algorithm");
		for (SegmentWithPValue s : selectedIslands) {
			int start;
			int end;
			if (codonAware) {
				// segment positions refer to codons; convert to nucleotide positions
				start = s.getStartPosition() * 3;
				end = s.getEndPosition() * 3;
			} else {
				start = s.getStartPosition();
				end = s.getEndPosition();
			}
			String sequenceName = namedSequences.get(s.getStringIndex()).getName();
			Log.printf(Log.Level.STANDARD, ">> %d %d %d %d %e \"%s\"", start, end, end-start, s.getPatternCount(), s.getPValue(), sequenceName);
			if (translateCoordinates) {
				try {
					StringTokenizer st = new StringTokenizer(sequenceName, " ", false);
					String chromosome = st.nextToken();
					int genomicStartPos = Integer.parseInt(st.nextToken());
					int genomicEndPos = Integer.parseInt(st.nextToken());
					boolean strand = parseStrand(st.nextToken());
					if (strand) {
						Log.printf(Log.Level.STANDARD, " %s %d %d", chromosome, genomicStartPos+start, genomicStartPos+end-1);
					} else {
						// sequence is the reverse complement: count positions from the interval's end
						Log.printf(Log.Level.STANDARD, " %s %d %d", chromosome, genomicEndPos-end+1, genomicEndPos-start);
					}
				} catch (NoSuchElementException e) {
					// fewer than four whitespace-separated tokens
					Log.errorln("Could not parse sequence name: \""+sequenceName+"\"");
					return 1;
				} catch (IllegalArgumentException e) {
					// non-numeric position (NumberFormatException) or invalid strand
					Log.errorln("Could not parse sequence name: \""+sequenceName+"\"");
					return 1;
				}
			}
			Log.println(Log.Level.STANDARD,"");
		}
		return 0;
	}

	/**
	 * Writes the histogram to the given file, one line "intervalMiddle count" per bin.
	 * The writer is always closed. Since {@link PrintWriter} never throws on write
	 * failures, errors are detected via {@link PrintWriter#checkError()}.
	 * @return true on success, false (after logging an error) otherwise
	 */
	private static boolean writeHistogram(Histogram histogram, String filename) {
		PrintWriter histogramWriter;
		try {
			histogramWriter = new PrintWriter(new FileWriter(filename));
		} catch (IOException e) {
			Log.errorln("Error writing file "+filename+": "+e.getMessage());
			return false;
		}
		try {
			for (Histogram.Bin b : histogram) {
				histogramWriter.println(String.format("%f %d", b.getIntervalMiddle(), b.getCount()));
			}
		} finally {
			histogramWriter.close();
		}
		// checkError() also reports failures that occurred while flushing/closing
		if (histogramWriter.checkError()) {
			Log.errorln("Error writing file "+filename);
			return false;
		}
		return true;
	}

	/**
	 * Parses a strand designator.
	 * @return true for "+", false for "-"
	 * @throws IllegalArgumentException for any other string
	 */
	public static boolean parseStrand(String strand) {
		if (strand.equals("+")) return true;
		if (strand.equals("-")) return false;
		throw new IllegalArgumentException();
	}
		
	/** Helper class to print progress in terms of examined segments and current segment length. */
	public static class ProgressPrinter {
		private final double increment;
		private final long totalCandidates;
		private final long maxLength;
		// number of completed increments (candidates / length) at the time of the last output
		private int n;
		private int m;
		private final String formatString;
		private ProgressPrinter(double increment, long totalCandidates, long maxLength) {
			this.increment = increment;
			this.totalCandidates = totalCandidates;
			this.maxLength = maxLength;
			n = 0;
			m = 0;
			// Field widths are the number of decimal digits of the maxima. They must be at
			// least 1, since a format like "%0d" is illegal and would throw when printing.
			int len1 = Long.toString(totalCandidates).length();
			int len2 = Long.toString(maxLength).length();
			formatString = "Progress: Segments: %" + len1 + "d / %" + len1 + "d (%5.1f%%), Length: %" + len2 + "d / %" + len2 + "d (%5.1f%%)\n";
		}
		/** Prints a progress line (verbose mode only) whenever another increment has been completed. */
		private void print(long candidatesDone, long lengthDone) {
			if (!Log.levelAtLeast(Log.Level.VERBOSE)) return;
			int nNew = (int)(1.0 / increment * candidatesDone / totalCandidates);
			int mNew = (int)(1.0 / increment * lengthDone / maxLength);
			if ((nNew>n) || (mNew>m)) {
				Log.printf(Log.Level.VERBOSE, formatString , candidatesDone, totalCandidates, 100.0*candidatesDone/totalCandidates, lengthDone, maxLength, 100.0*lengthDone/maxLength);
				n = nNew;
				m = mNew;
			}
		}
	}
	
	/** Returns the initial state/value distribution of a PAA conditional on being in one
	 *  of the given states: the PAA's state equilibrium restricted to these states and
	 *  renormalized, with the value obtained by applying the state's emission to the start value. */
	private double[][] getConditionalInitialDistribution(TextBasedPAA paa,	List<Integer> statesCondition) {
		double[][] distribution;
		double sum = 0.0;
		double[] stateEq = paa.convergeToStateEquilibrium(1e-14, 1000);
		for (int state : statesCondition) {
			sum += stateEq[state];
		}
		distribution = new double[paa.getStateCount()][paa.getValueCount()];
		for (int state : statesCondition) {
			distribution[state][paa.performOperation(state, paa.getStartValue(), paa.getEmission(state))] = stateEq[state] / sum;
		}
		return distribution;
	}

	/** Use a different initial distribution than returned by paa.stateValueStartDistribution(), which would 
	 *  have probability one of being in the DAA's start state. We want the DAA start distribution to be
	 *  in equilibrium (to allow matches to "stretch" into the first codon). All mass is at the start value.
	 **/
	private double[][] getCodonInitialDistribution(TextBasedPAA paa) {
		double[][] distribution;
		double[] stateEq = paa.convergeToStateEquilibrium(1e-14, 1000);
		distribution = new double[paa.getStateCount()][paa.getValueCount()];
		for (int state=0; state<paa.getStateCount(); ++state) {
			distribution[state][paa.getStartValue()] = stateEq[state];
		}
		return distribution;
	}

	/** Returns the right tail distribution of values given that process is in one of the
	 *  given states, i.e. result[v] = P(value >= v | state in statesCondition). */
	final private double[] conditionalRightCDF(double[][] distribution, List<Integer> statesCondition) {
		double[] rightCDF;
		// we have to condition on being in an accepting state 
		rightCDF = new double[distribution[0].length];
		double sum = 0.0;
		for (int value=0; value<rightCDF.length; ++value) {
			double p = 0.0;
			for (int state : statesCondition) {
				p += distribution[state][value];
			}
			rightCDF[value] = p;
			sum += p;
		}
		// normalize to 1.0 to get conditional probability distribution
		for (int i=0; i<rightCDF.length; ++i) rightCDF[i] /= sum;
		// sum up to get right-cumulative distribution
		for (int i=rightCDF.length-2; i>=0; --i) rightCDF[i] += rightCDF[i+1];
		return rightCDF;
	}

	/** In order of given list, choose island candidate unless it overlaps a previously chosen candidate.
	 *  Only the first islandsToRetain candidates are considered. Overlap is determined on the level
	 *  of CpG occurrences: two candidates overlap if they share a CpG occurrence index. */
	final private List<SegmentWithPValue> greedyChoice(List<SegmentWithPValue> islandCandidates, int islandsToRetain, List<int[]> occurrences) {
		// Create BitArray with one bit for every CpG. If CpG is covered by already
		// selected island, then the corresponding bit is set to one.
		// Asymptotically, there are faster algorithms for the job, but probably this 
		// is fast enough for our application
		BitArray[] coveredCGs = new BitArray[occurrences.size()];
		for (int i=0; i<occurrences.size(); ++i) {
			coveredCGs[i] = new BitArray(occurrences.get(i).length);
		}
		List<SegmentWithPValue> selectedIslands = new ArrayList<SegmentWithPValue>();
		for (int i=0; i<islandsToRetain; ++i) {
			SegmentWithPValue s = islandCandidates.get(i);
			int from = s.getStartIndex();
			int to = s.getEndIndex()+1;
			// skip if overlapping with already selected island
			if (!coveredCGs[s.getStringIndex()].allZero(from, to)) continue;
			selectedIslands.add(s);
			coveredCGs[s.getStringIndex()].setRange(from, to, 1);
		}
		return selectedIslands;
	}

	/** Sums probability for a given value over all states, i.e. returns the marginal P(value). */
	final private double computeValueProbability(double[][] distribution, int value) {
		double p = 0.0;
		for (int state=0; state<distribution.length; ++state) {
			p += distribution[state][value];
		}
		return p;
	}

}
